
[PyTorch] Add an API restore from function context to ensure tensors are detached #2772

Merged
ksivaman merged 2 commits into NVIDIA:main from kainzhong:fix/change_restore_api
Mar 19, 2026

Conversation

@kainzhong
Collaborator

Description

In quantized_tensor.py we only had restore_from_saved, whose typical usage is:

tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

However, after calling this API it is common to forget to detach tensor_objects from ctx. The function context then keeps a reference to it, and the tensor (along with its allocated memory) cannot be released until the next iteration, when the context is destroyed (see #2750).

This PR adds a new API that restores directly from the context (and discourages using restore_from_saved when restoring from a function context): it deletes the reference to the tensor objects after restoring, ensuring the memory is freed.
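The difference between the two patterns can be sketched as follows. This is a hedged illustration, not the actual Transformer Engine code: FakeCtx stands in for torch.autograd's function context, and the stub restore_from_saved only mimics the real helper's role of recombining metadata with saved data.

```python
# Sketch of the old vs. new restore pattern. FakeCtx and the stub
# restore_from_saved are stand-ins used only to illustrate the
# reference-lifetime difference described above.

class FakeCtx:
    pass

def restore_from_saved(tensor_objects, saved_tensors):
    # Stub: the real helper recombines quantized-tensor metadata
    # (tensor_objects) with the raw data in saved_tensors.
    return list(saved_tensors)

def backward_old(ctx):
    # Old pattern: easy to forget to drop ctx.tensor_objects, so the
    # context keeps the reference until it is destroyed next iteration.
    return restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

def restore_from_func_ctx(ctx):
    # New pattern: the helper drops the reference immediately after
    # restoring, so the memory can be freed within this iteration.
    tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
    ctx.tensor_objects = None
    return tensors

ctx = FakeCtx()
ctx.tensor_objects = ["meta"]
ctx.saved_tensors = ("data",)
tensors = restore_from_func_ctx(ctx)
print(tensors, ctx.tensor_objects)  # ['data'] None
```

With backward_old, ctx.tensor_objects would still hold ["meta"] after the call; with the new helper the reference is gone as soon as the restore completes.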

Fixes # (issue)

For example, with FP8 quantization and GroupedLinear:

[Memory profiler screenshots: before the fix vs. after the fix]

Notice the time when the selected memory section is released.

Test script:

test.py

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add restore_from_func_ctx which will delete tensor objects for users after restoring saved tensors

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 17, 2026

Greptile Summary

This PR introduces restore_from_func_ctx(ctx) as a thin wrapper around restore_from_saved that automatically sets ctx.tensor_objects = None after restoring, ensuring the function context does not retain a reference to quantized tensor metadata across iterations. All 9 call sites across the codebase are migrated to the new API.

  • Memory leak fix in 5 sites (grouped_linear, FusedAttnFunc, AttnFuncWithCPAndKVP2P, AttnFuncWithCPAndQKVOA2A, and the test's _custom_mha_fp8): these previously called restore_from_saved without ever nullifying ctx.tensor_objects, so tensor metadata (and associated GPU memory) was held alive until the next iteration destroyed the autograd context.
  • Deduplication in the remaining 4 sites (linear, layernorm_linear, layernorm_mlp._recompute, fuser): these previously called restore_from_saved followed by a manual ctx.tensor_objects = None; that two-step pattern is now a single call.
  • restore_from_func_ctx is correctly exported from both transformer_engine.pytorch and transformer_engine.pytorch.tensor.
  • The return_saved_tensors parameter lacks its : bool type annotation, which is inconsistent with the matching parameter in restore_from_saved.

Confidence Score: 4/5

  • Safe to merge; all migrations are semantically equivalent and the new API is a strict improvement over the previous two-step pattern.
  • The logic is correct across all call sites, the guard in restore_from_func_ctx handles both the missing-attribute and double-call cases, and the PR fixes latent memory leaks at several sites that were previously not nullifying ctx.tensor_objects. The only minor issue is a missing : bool type annotation on the return_saved_tensors parameter. Score of 4 rather than 5 only because no new automated tests are added to validate the memory-freeing behavior.
  • No files require special attention; all changes are straightforward API migrations.

Important Files Changed

  • transformer_engine/pytorch/quantized_tensor.py: Introduces restore_from_func_ctx, which wraps restore_from_saved and automatically nullifies ctx.tensor_objects after restoration. The guard covers both the missing-attribute and the already-called (None) case. Minor: the return_saved_tensors parameter is missing its : bool annotation.
  • transformer_engine/pytorch/module/grouped_linear.py: Migrates to restore_from_func_ctx. Previously did NOT nullify ctx.tensor_objects after restoration; this PR fixes that latent memory leak.
  • transformer_engine/pytorch/module/linear.py: Migrates to restore_from_func_ctx, removing the previously explicit ctx.tensor_objects = None and the intermediate saved_tensors = ctx.saved_tensors variable. Semantically equivalent.
  • transformer_engine/pytorch/module/layernorm_linear.py: Migrates to restore_from_func_ctx, removing the explicit ctx.tensor_objects = None. Semantically equivalent.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Migrates _recompute to use restore_from_func_ctx. The nullification now happens before the optional recomputed _forward call, the same order as the previous explicit ctx.tensor_objects = None. No behavioral change.
  • transformer_engine/pytorch/ops/fuser.py: Migrates _OperationFuserAutogradFunction.backward to restore_from_func_ctx, removing the previously explicit func_ctx.tensor_objects = None.
  • transformer_engine/pytorch/attention/dot_product_attention/backends.py: Migrates FusedAttnFunc.backward to restore_from_func_ctx. Previously did NOT nullify ctx.tensor_objects; memory leak fixed.
  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: Migrates two autograd function backward passes (AttnFuncWithCPAndKVP2P, AttnFuncWithCPAndQKVOA2A) to restore_from_func_ctx. Both previously did NOT nullify ctx.tensor_objects; memory leak fixed in both.
  • transformer_engine/pytorch/__init__.py: Exports restore_from_func_ctx from the top-level package. Straightforward addition.
  • transformer_engine/pytorch/tensor/__init__.py: Re-exports restore_from_func_ctx from the tensor subpackage and adds it to __all__. Consistent with how restore_from_saved is exposed.

Sequence Diagram

sequenceDiagram
    participant Backward as backward(ctx)
    participant RRFC as restore_from_func_ctx
    participant RFS as restore_from_saved
    participant Ctx as ctx

    Backward->>RRFC: "restore_from_func_ctx(ctx)"
    RRFC->>Ctx: "read ctx.tensor_objects"
    RRFC->>Ctx: "read ctx.saved_tensors"
    RRFC->>RFS: "restore_from_saved(tensor_objects, saved_tensors)"
    RFS-->>RRFC: "restored tensor list"
    RRFC->>Ctx: "ctx.tensor_objects = None"
    RRFC-->>Backward: "restored tensor list"

Last reviewed commit: "Merge branch 'main' ..."

Comment on lines +187 to +188
if not hasattr(ctx, "tensor_objects"):
    raise AttributeError("ctx must have .tensor_objects to restore saved tensors")

P1 hasattr guard misses the None case after first call

After the first successful call to restore_from_func_ctx, ctx.tensor_objects is set to None. If the function is ever called a second time on the same context (e.g., due to programming error or future refactoring), hasattr(ctx, "tensor_objects") will still return True (the attribute exists, it's just None), and the code will proceed to call restore_from_saved(None, ctx.saved_tensors). This causes an unhelpful TypeError: 'NoneType' object is not iterable deep inside restore_from_saved rather than a clear AttributeError here.

The guard should also check for None:

Suggested change
- if not hasattr(ctx, "tensor_objects"):
-     raise AttributeError("ctx must have .tensor_objects to restore saved tensors")
+ if not hasattr(ctx, "tensor_objects") or ctx.tensor_objects is None:
+     raise AttributeError("ctx must have .tensor_objects to restore saved tensors")
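The double-call failure mode described above can be reproduced with a small stand-alone sketch. FakeCtx and the stub restore_from_saved are stand-ins, not Transformer Engine's real code; the guard shown includes the suggested None check, so the second call fails with a clear AttributeError instead of a confusing TypeError.

```python
# Sketch of the double-call scenario with the strengthened guard.
# FakeCtx and restore_from_saved are illustrative stand-ins.

class FakeCtx:
    pass

def restore_from_saved(tensor_objects, saved_tensors):
    # Stub: iterates over tensor_objects, as the real helper does.
    # Passing None here would raise an unhelpful TypeError.
    return [obj for obj in tensor_objects]

def restore_from_func_ctx(ctx):
    # Guard covers both the missing attribute and the None-after-first-call case.
    if not hasattr(ctx, "tensor_objects") or ctx.tensor_objects is None:
        raise AttributeError("ctx must have .tensor_objects to restore saved tensors")
    tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
    ctx.tensor_objects = None
    return tensors

ctx = FakeCtx()
ctx.tensor_objects = ["meta"]
ctx.saved_tensors = ("data",)
restore_from_func_ctx(ctx)          # first call succeeds
try:
    restore_from_func_ctx(ctx)      # second call on the same context
except AttributeError as e:
    print("clear error:", e)
```

Without the None check, the second call would reach restore_from_saved(None, ...) and die with "'NoneType' object is not iterable" deep inside the helper.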

list[Optional[torch.Tensor]],
]
):
"""Recombine the tensor data and metadata during backward pass."""

P2 Docstring omits key behavioral difference from restore_from_saved

The docstring for restore_from_func_ctx says only "Recombine the tensor data and metadata during backward pass," which is identical in meaning to restore_from_saved. The defining — and intentional — behavior of this new function (nullifying ctx.tensor_objects after the restore so memory can be freed) is not documented here, even though it's the primary motivation for adding the new API.

Suggested change
- """Recombine the tensor data and metadata during backward pass."""
+ """Recombine the tensor data and metadata during backward pass.
+
+ Unlike `restore_from_saved`, this function deletes `ctx.tensor_objects`
+ after restoring (by setting it to None), which allows the reference to
+ the tensor objects to be released and the underlying memory to be freed
+ at the end of the current iteration rather than when the function context
+ is destroyed.
+ """

@kainzhong kainzhong force-pushed the fix/change_restore_api branch from 5a03f56 to 4471dcb on March 17, 2026 21:47
…d from ctx

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
@kainzhong kainzhong force-pushed the fix/change_restore_api branch from 4471dcb to 3001bbd on March 17, 2026 21:49
Member

@ksivaman ksivaman left a comment


LGTM

@ksivaman
Member

/te-ci pytorch L0 L1

@ksivaman ksivaman merged commit 15760a5 into NVIDIA:main Mar 19, 2026
11 of 15 checks passed